"""
#
# flow stability telegram
#
# Copyright (C) 2022 Alexandre Bovet <alexandre.bovet@math.uzh.ch>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import sys
import os
PACKAGE_PARENT = '../flow_stability'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

from FlowStability import SparseClustering
from TemporalNetwork import set_to_zeroes, inplace_csr_row_normalize
from SparseStochMat import sparse_autocov_mat
from parallel_clustering import compute_parallel_clustering, compute_parallel_nvi

import networkx as nx
import numpy as np
import pandas as pd
from scipy.sparse import diags, eye, csc_matrix
from parallel_expm import compute_parallel_expm

from argparse import ArgumentParser

#%%
# Command-line interface: two mandatory string options.
ap = ArgumentParser()
for option in ('--graph_file', '--graph_name'):
    ap.add_argument(option, type=str, required=True)

inargs = vars(ap.parse_args())
graph_file, graph_name = inargs['graph_file'], inargs['graph_name']

# Echo the parsed arguments, one (name, value) tuple per line.
print('Arguments:')
print(*inargs.items(), sep='\n')

#%%

# Number of worker processes handed to the parallel expm/clustering/NVI calls.
nproc = 12
# Number of clustering repetitions requested per time value.
num_repeat = 50
# Default number of normalize/symmetrize passes in normalize_S_inplace.
num_norm_iter = 10
# Tolerances passed to set_to_zeroes for the transition matrices and for the
# PT matrices, respectively.
trans_mat_tol = 1e-9
cov_mat_tol = 1e-6

# Time values: 20 log-spaced points between 1e-3 and 1e4, with the two
# smallest and the two largest dropped (16 values remain).
times = np.logspace(-3, 4, num=20)[2:-2]

# raise Exception

#%%

# Load the pickled graph. ``nx.read_gpickle`` was removed in networkx 3.0, so
# fall back to the standard pickle module when it is unavailable, honouring
# the ``.gz`` / ``.bz2`` extensions that ``read_gpickle`` used to decompress.
# NOTE: unpickling can execute arbitrary code; only load trusted graph files.
if hasattr(nx, 'read_gpickle'):
    G = nx.read_gpickle(graph_file)
else:
    import bz2
    import gzip
    import pickle

    _openers = {'.gz': gzip.open, '.bz2': bz2.open}
    _opener = _openers.get(os.path.splitext(graph_file)[1], open)
    with _opener(graph_file, 'rb') as _fh:
        G = pickle.load(_fh)

#%%
# Adjacency matrix with rows/columns ordered by sorted node label.
A = nx.adjacency_matrix(G, nodelist=sorted(G.nodes()))
# Row sums (out) and column sums (in) of the adjacency matrix.
# ravel() rather than squeeze() keeps these 1-D even for a single-node graph.
out_degs = np.asarray(A.sum(1)).ravel()
in_degs = np.asarray(A.sum(0)).ravel()


# Positions (in the sorted node list) of nodes with a positive out / in sum;
# these are the starting node sets for the forward / backward pruning.
f_nodes = np.where(out_degs > 0)[0]
b_nodes = np.where(in_degs > 0)[0]

#%%

def normalize_S_inplace(PT,p1,num_norm_iter=num_norm_iter):
    """Iteratively row-normalize and symmetrize ``PT``, updating it in place.

    Each pass calls ``inplace_csr_row_normalize(PT, p1)`` and then replaces
    the matrix by ``(PT + PT.T)/2``. The symmetrization necessarily builds a
    new sparse matrix, so the final result is written back into the caller's
    ``PT`` object (its ``data``/``indices``/``indptr`` arrays are replaced).
    Previously the local name was merely rebound, which left the caller with
    a matrix that had only received the first row normalization.

    Parameters
    ----------
    PT : scipy.sparse CSR matrix
        Square matrix to normalize; every row sum must be positive.
    p1 : numpy.ndarray
        Vector handed to ``inplace_csr_row_normalize`` (presumably the target
        row sums -- confirm against that helper) and used for the diagnostics.
    num_norm_iter : int
        Number of normalize/symmetrize passes.

    Returns
    -------
    The same ``PT`` object, for convenience.
    """
    # normalize PT
    assert (np.array(PT.sum(1)).squeeze() >0 ).all()
    max_S_row_diff = np.abs(np.array(PT.sum(1)).squeeze() - p1).max()
    print('PID ', os.getpid(),
              f', max(row_sum|PT - pTp|) = {max_S_row_diff}, PT sum = {PT.data.sum()}')

    # now normalize; ``normalized`` starts as an alias of the caller's matrix
    # and becomes a fresh matrix after the first symmetrization
    normalized = PT
    for _ in range(num_norm_iter):
        inplace_csr_row_normalize(normalized, p1)
        normalized = (normalized + normalized.T)/2

    if normalized is not PT:
        # Put the result in canonical form so the flag set below is truthful,
        # then swap the arrays into the caller's object.
        normalized.sum_duplicates()
        PT.data = normalized.data
        PT.indices = normalized.indices
        PT.indptr = normalized.indptr
        PT.has_canonical_format = True

    max_S_row_diff_norm = np.abs(np.array(PT.sum(1)).squeeze() - p1).max()

    print('PID ', os.getpid(),
          f', after normalization: max(row_sum|PT - pTp|) = {max_S_row_diff_norm}, PT sum = {PT.data.sum()}')

    return PT
    


#%% Forward and backward pruning of nodes with zero out-degree

def _prune_zero_out_degree(adj, nodes):
    """Iteratively drop nodes whose out-degree is zero in the induced subgraph.

    Removing a node can leave others without outgoing edges, so the
    restriction of ``adj`` to ``nodes`` is recomputed until every remaining
    row has a non-zero sum. The number of remaining nodes is printed after
    every pass.

    Parameters
    ----------
    adj : scipy sparse matrix
        Square adjacency matrix (rows are sources).
    nodes : numpy.ndarray
        Integer indices into ``adj`` of the initial node set.

    Returns
    -------
    (nodes, sub_adj, sub_out_degs) : the surviving indices, the adjacency
    matrix restricted to them, and its row sums as a 1-D array.
    """
    while True:
        sub_adj = adj[nodes, :][:, nodes]
        # ravel() keeps the vector 1-D even when a single node remains
        sub_out_degs = np.asarray(sub_adj.sum(1)).ravel()

        dangling = np.where(sub_out_degs == 0)[0]
        if dangling.size:
            # remove from the current set
            nodes = np.delete(nodes, dangling)

        print(nodes.size)

        if not dangling.size:
            return nodes, sub_adj, sub_out_degs


# Forward network: prune on the adjacency matrix itself.
f_nodes, A_f, f_out_degs = _prune_zero_out_degree(A, f_nodes)

# Backward network: same pruning on the transposed (edge-reversed) adjacency.
A_rev = A.T
b_nodes, A_b, b_out_degs = _prune_zero_out_degree(A_rev, b_nodes)


#%% Laplacians

def _random_walk_laplacian(adj, out_degs):
    """Return ``I - D^-1 @ adj`` as a CSC matrix, with ``D = diag(out_degs)``."""
    identity = eye(adj.shape[0])
    transition = diags(1/out_degs) @ adj
    return csc_matrix(identity - transition)


def _assert_zero_row_sums(laplacian, out_degs):
    """Sanity check: every row of ``laplacian`` must sum to (numerically) zero."""
    row_sums = np.array(laplacian.sum(1)).squeeze()
    assert np.allclose(row_sums, np.zeros_like(out_degs))


# Forward Laplacian.
L_f = _random_walk_laplacian(A_f, f_out_degs)
_assert_zero_row_sums(L_f, f_out_degs)

# Backward Laplacian (built from the reversed, pruned adjacency).
L_b = _random_walk_laplacian(A_b, b_out_degs)
_assert_zero_row_sums(L_b, b_out_degs)

#%%



def _transition_matrix(laplacian, t):
    """Build the transition matrix for time ``t`` from ``laplacian``.

    Calls ``compute_parallel_expm`` on ``-t*laplacian``, checks that the rows
    sum to one, zeroes entries via ``set_to_zeroes`` (tolerance
    ``trans_mat_tol``) and row-normalizes the result in place.
    """
    trans = compute_parallel_expm(-t*laplacian, nproc=nproc, normalize_rows=False)

    assert np.allclose(np.array(trans.sum(1)).squeeze(), np.ones(trans.shape[0]))

    set_to_zeroes(trans, tol=trans_mat_tol)
    inplace_csr_row_normalize(trans)
    return trans


def _uniform_autocov(trans, label):
    """Build the ``sparse_autocov_mat`` of ``trans`` for a uniform start.

    ``label`` is printed just before the normalization step.
    """
    n_nodes = trans.shape[0]
    p1 = np.ones(n_nodes)/n_nodes

    p2 = p1 @ trans

    # avoid dividing by zero where p2 vanishes
    p2_for_inv = p2.copy()
    p2_for_inv[p2_for_inv == 0] = 1

    trans_inv = diags(1/p2_for_inv) @ trans.T @ diags(p1)

    PT = diags(p1) @ trans @ trans_inv
    set_to_zeroes(PT, tol=cov_mat_tol)

    print(label)
    normalize_S_inplace(PT, p1)

    assert np.allclose(PT.data.sum(), 1)

    return sparse_autocov_mat(PT=PT, p1=p1, p2=p1)


# Make sure the output directory exists before starting the (expensive)
# computations, so results are not lost to a missing folder.
os.makedirs('data/xrw', exist_ok=True)

for t in times:
    print(t)

    # forward and backward transition matrices
    T = _transition_matrix(L_f, t)
    T_rev = _transition_matrix(L_b, t)

    # forward and backward autocovariance matrices
    Sforw = _uniform_autocov(T, 'normalizing S_forw')
    Sback_rev = _uniform_autocov(T_rev, 'normalizing S_back')

    pd.to_pickle({"T":T,
                  'T_rev':T_rev,
                  'Sforw':Sforw,
                  'Sback_rev':Sback_rev,
                  'forward_nodes' : f_nodes,
                  'backward_nodes' : b_nodes,
                  't':t}
                  ,f'data/xrw/xrw_flow_stab_Tmats_nokout_{graph_name}_{t:0.2e}.pickle', protocol=4)


    print('starting forward clustering')

    c_forw = SparseClustering(S=Sforw)
    forward_clusts, forward_stabs, forward_seeds = \
        compute_parallel_clustering(c_forw, num_repeat=num_repeat, nproc=nproc,
                                verbose=True,
                                clust_verbose=False, print_num_loops=True)

    print('starting backward clustering')

    c_back = SparseClustering(S=Sback_rev)
    backward_clusts, backward_stabs, backward_seeds = \
        compute_parallel_clustering(c_back, num_repeat=num_repeat, nproc=nproc,
                                verbose=True,
                                clust_verbose=False, print_num_loops=True)


    forward_nvis = compute_parallel_nvi(forward_clusts, Sforw.shape[0],
                                        nproc=nproc, verbose=True)

    backward_nvis = compute_parallel_nvi(backward_clusts, Sback_rev.shape[0],
                                        nproc=nproc, verbose=True)

    print('saving for t=',t)

    pd.to_pickle({'forward_nodes': f_nodes,
                  'backward_nodes': b_nodes,
                  'forward_clusts' : forward_clusts,
                  'forward_stabs' : forward_stabs,
                  'backward_clusts' : backward_clusts,
                  'backward_stabs' : backward_stabs,
                  't':t,
                  'forward_seeds': forward_seeds,
                  'backward_seeds' : backward_seeds,
                  'num_repeat': num_repeat,
                  'forward_nvis': forward_nvis,
                  'backward_nvis' : backward_nvis},
                 f'data/xrw/xrw_flow_stab_clusts_nokout_{graph_name}_{t:0.2e}.pickle', protocol=4)
    
    